#
#  Copyright (C) 2000-2008  greg Landrum and Rational Discovery LLC
#   All Rights Reserved
#
""" Utilities for data manipulation

**FILE FORMATS:**

 - *.qdat files* contain quantized data suitable for
  feeding to learning algorithms.

  The .qdat file, written by _DecTreeGui_, is structured as follows:

   1) Any number of lines which are ignored.

   2) A line containing the string 'Variable Table'

      any number of variable definitions in the format:

      '# Variable_name [quant_bounds]'

        where '[quant_bounds]' is a list of the boundaries used for quantizing
         that variable.  If the variable is inherently integral (i.e. not
         quantized), this can be an empty list.

   3) A line beginning with '# ----' which signals the end of the variable list

   4) Any number of lines containing data points, in the format:

      'Name_of_point var1 var2 var3 .... varN'

      all variable values should be integers

   Throughout, it is assumed that varN is the result

 - *.dat files* contain the same information as .qdat files, but the variable
   values can be anything (floats, ints, strings).  **These files should
   still contain quant_bounds!**

 - *.qdat.pkl files* contain a pickled (binary) representation of
   the data read in.  They store, in order:

    1) A python list of the variable names

    2) A python list of lists with the quantization bounds

    3) A python list of the point names

    4) A python list of lists with the data points

"""
from __future__ import print_function

import csv
import random
import re

import numpy

from rdkit.DataStructs import BitUtils
from rdkit.ML.Data import MLData
from rdkit.six import integer_types
from rdkit.six.moves import cPickle
from rdkit.utils import fileutils


def permutation(nToDo):
  """ returns a random permutation of the integers from 0 to nToDo-1

    **Arguments**

      - nToDo: the number of elements in the permutation

    **Returns**

      a list containing every integer in range(nToDo) exactly once, in
      random order

    **Notes**

      - the random numbers come from _random.random()_, so the result is
        reproducible after a call to _random.seed()_

  """
  res = list(range(nToDo))
  # Fisher-Yates shuffle driven by random.random().  This is the algorithm
  # random.shuffle() applied when it was handed an explicit `random` argument.
  # That argument was removed in Python 3.11, so the loop is spelled out here
  # to stay compatible while producing the same sequence for a given seed.
  for i in reversed(range(1, len(res))):
    j = int(random.random() * (i + 1))
    res[i], res[j] = res[j], res[i]
  return res


def WriteData(outFile, varNames, qBounds, examples):
  """ dumps a data set to a file object in .qdat format

    **Arguments**

      - outFile: a file object (anything with a _write()_ method taking text)

      - varNames: a list of variable names

      - qBounds: the quantization bounds, indexed in parallel with _varNames_

      - examples: the data points; each one is written on its own line with
        its elements converted using _str()_ and separated by single spaces

  """
  rule = '# ----------\n'
  outFile.write('# Quantized data from DataUtils\n')
  outFile.write(rule)
  outFile.write('# Variable Table\n')
  # one "# name [bounds]" line per variable
  for idx, name in enumerate(varNames):
    outFile.write('# %s %s\n' % (name, str(qBounds[idx])))
  # the second rule marks the end of the variable table
  outFile.write(rule)
  for example in examples:
    outFile.write(' '.join(map(str, example)) + '\n')


def ReadVars(inFile):
  """ pulls the variable table out of a .qdat or .dat file

    **Arguments**

      - inFile: a file object

    **Returns**

      a 2-tuple containing:

        1) varNames: a list of the variable names

        2) qbounds: a list with one list of floats (the quantization bounds)
           per variable; variables without bounds get an empty list

    **Notes**

      - the file is advanced to the line matching 'Variable Table' and then
        read up to and including the first line containing '# ----'

  """
  fileutils.MoveToMatchingLine(inFile, 'Variable Table')
  varNames = []
  rawBounds = []
  line = inFile.readline()
  while '# ----' not in line:
    # skip the first two characters (the '# ' prefix in the documented
    # format) and separate the name from the bracketed list of bounds
    pieces = line[2:].split('[')
    varNames.append(pieces[0].strip())
    # the last two characters are assumed to be the closing ']' and the
    # line terminator
    rawBounds.append(pieces[1][:-2])
    line = inFile.readline()

  qBounds = []
  for text in rawBounds:
    if text == '':
      qBounds.append([])
    else:
      qBounds.append([float(tok) for tok in text.split(',')])
  return varNames, qBounds


def ReadQuantExamples(inFile):
  """ reads the data points from a .qdat file

    **Arguments**

      - inFile: a file object, read from its current position to the end

    **Returns**

      a 2-tuple containing:

        1) the names of the examples (the first field of each line)

        2) a list of lists containing the examples themselves

    **Note**

      - lines starting with '#' and lines with fewer than two fields are
        skipped

      - because this is a .qdat file, every value is converted with _int()_;
        a value that is not an integer raises a ValueError

  """
  isComment = re.compile(r'^#').search
  fieldSep = re.compile(r'[\ ]+|[\t]+')
  names = []
  examples = []
  # readline() returns '' at end of file, which terminates the iteration
  for line in iter(inFile.readline, ''):
    if isComment(line):
      continue
    fields = fieldSep.split(line)
    if len(fields) <= 1:
      continue
    examples.append([int(tok) for tok in fields[1:]])
    names.append(fields[0])
  return names, examples


def ReadGeneralExamples(inFile):
  """ reads the data points from a .dat file

    **Arguments**

      - inFile: a file object, read from its current position to the end

    **Returns**

      a 2-tuple containing:

        1) the names of the examples (the first field of each line)

        2) a list of lists containing the examples themselves

    **Note**

      - each value is converted to an int if possible, otherwise to a float;
        values for which both conversions fail are kept as strings

      - lines starting with '#' are skipped

      - the final whitespace-delimited field of every line is discarded
        before the values are converted

  """

  def _coerce(text):
    # int first, then float, otherwise leave the text untouched
    try:
      return int(text)
    except ValueError:
      try:
        return float(text)
      except ValueError:
        return text

  isComment = re.compile(r'^#').search
  fieldSep = re.compile(r'[\ ]+|[\t]+')
  names = []
  examples = []
  # readline() returns '' at end of file, which terminates the iteration
  for line in iter(inFile.readline, ''):
    if isComment(line):
      continue
    # NOTE(review): the slice drops the last field.  That removes the bare
    # line terminator when a line ends in whitespace, but it drops a real
    # value otherwise -- confirm against the files this is used on.
    fields = fieldSep.split(line)[:-1]
    if len(fields) > 1:
      examples.append([_coerce(tok) for tok in fields[1:]])
      names.append(fields[0])
  return names, examples


def BuildQuantDataSet(fileName):
  """ constructs a quantized data set from the contents of a .qdat file

    **Arguments**

      - fileName: the name of the .qdat file

    **Returns**

      an _MLData.MLQuantDataSet_

    **Notes**

      - the variable table is read with _ReadVars()_ and the points with
        _ReadQuantExamples()_

  """
  with open(fileName, 'r') as inFile:
    names, bounds = ReadVars(inFile)
    ptNames, points = ReadQuantExamples(inFile)
  return MLData.MLQuantDataSet(points, qBounds=bounds, varNames=names, ptNames=ptNames)


def BuildDataSet(fileName):
  """ constructs a data set from the contents of a .dat file

    **Arguments**

      - fileName: the name of the .dat file

    **Returns**

      an _MLData.MLDataSet_

    **Notes**

      - the variable table is read with _ReadVars()_ and the points with
        _ReadGeneralExamples()_

  """
  with open(fileName, 'r') as inFile:
    names, bounds = ReadVars(inFile)
    ptNames, points = ReadGeneralExamples(inFile)
  return MLData.MLDataSet(points, qBounds=bounds, varNames=names, ptNames=ptNames)


def CalcNPossibleUsingMap(data, order, qBounds, nQBounds=None, silent=True):
  """ determines how many values each variable in a data set can take

   **Arguments**

     - data: a list of examples

     - order: the ordering map between the variables in _data_ and _qBounds_;
       variable _i_ is read from position _order[i]_ of each example

     - qBounds: the quantization bounds for the variables

     - nQBounds: (optional) a sequence parallel to _order_; a variable with a
       nonzero entry here is reported as 0 without looking at the data

     - silent: (optional) if false, diagnostic information is printed

   **Returns**

      a list with the number of possible values each variable takes on in the data set

   **Notes**

     - for a variable with a non-empty entry in _qBounds_ the result is the
       number of bounds plus one

     - for the remaining variables the result is the largest value found in
       the data plus one, provided every value is an integral number;
       as soon as a non-integral or non-numeric value is seen the variable
       is reported as 0

  """
  numericTypes = integer_types + (float, numpy.int64, numpy.int32, numpy.int16)

  if not silent:
    print('order:', order, len(order))
    print('qB:', qBounds)
  assert (qBounds and len(order) == len(qBounds)) or (nQBounds and len(order) == len(nQBounds)), \
      'order/qBounds mismatch'

  nVars = len(order)
  # -1 is the "unknown" marker; it turns into 0 when the results are shifted
  # by one on the way out
  counts = [-1] * nVars
  # the variables that still need to be worked out from the data
  pending = []
  for idx in range(nVars):
    if nQBounds and nQBounds[idx] != 0:
      continue
    if len(qBounds[idx]) > 0:
      counts[idx] = len(qBounds[idx])
      continue
    pending.append(idx)

  for ptIdx in range(len(data)):
    # iterate over a copy because columns are dropped along the way
    for col in list(pending):
      value = data[ptIdx][order[col]]
      if type(value) not in numericTypes:
        if not silent:
          print('bye bye col %d: %s' % (col, repr(value)))
        integral = False
      else:
        integral = int(value) == value
      if integral:
        counts[col] = max(int(value), counts[col])
      else:
        counts[col] = -1
        pending.remove(col)
  return [int(x) + 1 for x in counts]


def WritePickledData(outName, data):
  """ pickles a data set to a .qdat.pkl or .dat.pkl file

    **Arguments**

      - outName: the name of the file to be used

      - data: either an _MLData.MLDataSet_ or an _MLData.MLQuantDataSet_

    **Notes**

      - four objects are pickled one after another: the variable names, the
        quantization bounds, the point names and the data themselves

  """
  # collect everything before the output file is opened
  payload = (
    data.GetVarNames(),
    data.GetQuantBounds(),
    data.GetPtNames(),
    data.GetAllData(),
  )
  with open(outName, 'wb+') as outFile:
    for item in payload:
      cPickle.dump(item, outFile)


def TakeEnsemble(vect, ensembleIds, isDataVect=False):
  """ pulls the elements with the given indices out of a vector

    **Arguments**

      - vect: the sequence to pick from

      - ensembleIds: the indices of the elements to keep

      - isDataVect: (optional) if set, _vect_ is treated as a data point:
        its first and last elements are always kept and _ensembleIds_
        is applied to the elements after the first one

    **Returns**

      a list with the selected elements

  >>> v = [10,20,30,40,50]
  >>> TakeEnsemble(v,(1,2,3))
  [20, 30, 40]
  >>> v = ['foo',10,20,30,40,50,1]
  >>> TakeEnsemble(v,(1,2,3),isDataVect=True)
  ['foo', 20, 30, 40, 1]

  """
  if not isDataVect:
    return [vect[idx] for idx in ensembleIds]
  # shift by one to step over the leading element
  shifted = [idx + 1 for idx in ensembleIds]
  return [vect[0]] + [vect[idx] for idx in shifted] + [vect[-1]]


def DBToData(dbName, tableName, user='sysdba', password='masterkey', dupCol=-1, what='*', where='',
             join='', pickleCol=-1, pickleClass=None, ensembleIds=None):
  """ constructs  an _MLData.MLDataSet_ from a database

    **Arguments**

      - dbName: the name of the database to be opened

      - tableName: the table name containing the data in the database

      - user: the user name to be used to connect to the database

      - password: the password to be used to connect to the database

      - dupCol: handed to the query as its _removeDups_ argument

      - what: the fields to be retrieved

      - where: handed to the query as its _where_ argument

      - join: handed to the query as its _join_ argument

      - pickleCol: (optional) if not negative, the position (counted after
        the first column has been removed) of a column holding serialized
        objects

      - pickleClass: (optional) a callable used to rebuild the objects in
        _pickleCol_ from their string form.  If it fails once, unpickling is
        used for that row and all the following ones.

      - ensembleIds: (optional) the indices to keep, either bits of the
        object in _pickleCol_ or, without a pickle column, columns of each row

    **Returns**

       an _MLData.MLDataSet_

    **Notes**

      - this uses Dbase.DataUtils functionality

      - the first column of every row is used as the point name

  """
  from rdkit.Dbase.DbConnection import DbConnect
  conn = DbConnect(dbName, tableName, user, password)
  rows = conn.GetData(fields=what, where=where, join=join, removeDups=dupCol, forceList=1)
  vals = []
  ptNames = []
  classWorks = True
  for row in rows:
    fields = list(row)
    ptNames.append(fields.pop(0))
    if pickleCol < 0:
      if ensembleIds:
        fields = TakeEnsemble(fields, ensembleIds, isDataVect=True)
    else:
      # NOTE(review): unpickling runs arbitrary code; this is only safe if
      # the database contents are trusted.
      if pickleClass and classWorks:
        try:
          fields[pickleCol] = pickleClass(str(fields[pickleCol]))
        except Exception:
          # the class could not handle the text: fall back to unpickling and
          # do not try the class again
          fields[pickleCol] = cPickle.loads(str(fields[pickleCol]))
          classWorks = False
      else:
        fields[pickleCol] = cPickle.loads(str(fields[pickleCol]))
      if ensembleIds:
        fields[pickleCol] = BitUtils.ConstructEnsembleBV(fields[pickleCol], ensembleIds)
    vals.append(fields)
  varNames = conn.GetColumnNames(join=join, what=what)
  return MLData.MLDataSet(vals, varNames=varNames, ptNames=ptNames)


def TextToData(reader, ignoreCols=None, onlyCols=None):
  """ constructs  an _MLData.MLDataSet_ from a bunch of text

    **Arguments**

      - reader needs to be iterable and return lists of elements
        (like a csv.reader).  The first row must hold the column names.

      - ignoreCols: (optional) names of columns to leave out; only used
        when _onlyCols_ is not provided

      - onlyCols: (optional) names of the columns to keep, in the order in
        which they should appear in the result

    **Returns**

       an _MLData.MLDataSet_

    **Raises**

      - ValueError if a name in _onlyCols_ is not among the column names or
        if a row does not have the same length as the header

    **Notes**

      - the first column that is kept provides the point names

      - the remaining values are converted to ints if possible, otherwise to
        floats; values for which both conversions fail are kept as strings

      - empty rows are skipped

  """

  varNames = next(reader)
  if not onlyCols:
    ignored = ignoreCols if ignoreCols is not None else ()
    keepCols = [i for i, name in enumerate(varNames) if name not in ignored]
  else:
    # map each header name to its position; if a name is repeated in the
    # header the last occurrence wins
    colIndex = {}
    for i, name in enumerate(varNames):
      colIndex[name] = i
    # A requested column that is absent used to be left at index -1, which
    # silently selected the last column.  Fail loudly instead.
    missing = [name for name in onlyCols if name not in colIndex]
    if missing:
      raise ValueError('columns not found in header: %s' % ', '.join(str(x) for x in missing))
    keepCols = [colIndex[name] for name in onlyCols]

  nCols = len(varNames)
  varNames = tuple([varNames[x] for x in keepCols])
  nVars = len(varNames)
  vals = []
  ptNames = []
  for splitLine in reader:
    if len(splitLine):
      if len(splitLine) != nCols:
        raise ValueError('unequal line lengths')
      tmp = [splitLine[x] for x in keepCols]
      # the first kept column is the point name, the rest are the values
      ptNames.append(tmp[0])
      pt = [None] * (nVars - 1)
      for j in range(nVars - 1):
        try:
          val = int(tmp[j + 1])
        except ValueError:
          try:
            val = float(tmp[j + 1])
          except ValueError:
            val = str(tmp[j + 1])
        pt[j] = val
      vals.append(pt)
  data = MLData.MLDataSet(vals, varNames=varNames, ptNames=ptNames)
  return data


def TextFileToData(fName, onlyCols=None):
  """ constructs a data set from a delimited text file using _TextToData()_

    **Arguments**

      - fName: the name of the file.  If the text after its last '.' is
        'csv' (in any case) the file is read as comma-separated text,
        otherwise as tab-separated text.

      - onlyCols: (optional) passed on to _TextToData()_

    **Returns**

       whatever _TextToData()_ returns

  """
  suffix = fName.rsplit('.', 1)[-1]
  # ',' is the default delimiter of csv.reader
  delimiter = ',' if suffix.upper() == 'CSV' else '\t'
  with open(fName, 'r') as inF:
    return TextToData(csv.reader(inF, delimiter=delimiter), onlyCols=onlyCols)


def InitRandomNumbers(seed):
  """ Seeds the random number generators

    **Arguments**

      - seed: a 2-tuple containing integers to be used as the random number seeds

    **Notes**

      - this seeds both the RDRandom generator and the one in the standard
        Python _random_ module

      - only the first element of _seed_ is used; both generators get it

  """
  from rdkit import RDRandom
  first = seed[0]
  for seeder in (RDRandom.seed, random.seed):
    seeder(first)


def FilterData(inData, val, frac, col=-1, indicesToUse=None, indicesOnly=0):
  """ randomly picks points so that a given fraction of them has a target value

    The points are divided into "targets" (those whose entry in column
    _col_ equals _val_) and "others".  Points are then discarded at random
    from whichever group is over-represented, so that the targets make up
    (approximately) _frac_ of the points that are kept.

    **Arguments**

      - inData: a sequence of data points

      - val: the target value

      - frac: the desired fraction of target points, between 0 and 1

      - col: (optional) the column holding the value to be checked

      - indicesToUse: (optional) if provided, only the points at these
        positions in _inData_ are considered

      - indicesOnly: (optional) if set, indices into _inData_ are returned
        instead of the points themselves

    **Returns**

      a 2-tuple containing:

        1) the points (or indices) that were kept

        2) the points (or indices) that were rejected

      both lists are in random order

    **Raises**

      - ValueError if _frac_ is outside [0,1], if _col_ cannot be read from
        the first point, or if no point has the target value

    **Notes**

      - the random choices are made with _permutation()_

  """
  if frac < 0 or frac > 1:
    raise ValueError('filter fraction out of bounds')
  try:
    inData[0][col]
  except IndexError:
    raise ValueError('target column index out of range')

  # convert the input data to a list and sort them
  if indicesToUse:
    tmp = [inData[x] for x in indicesToUse]
  else:
    tmp = list(inData)
  nOrig = len(tmp)
  sortOrder = list(range(nOrig))
  sortOrder.sort(key=lambda x: tmp[x][col])
  tmp = [tmp[x] for x in sortOrder]

  # find the start of the entries with value val
  start = 0
  while start < nOrig and tmp[start][col] != val:
    start += 1
  if start >= nOrig:
    # %s rather than %d so that a non-numeric target value produces this
    # error instead of a TypeError from the string formatting
    raise ValueError('target value (%s) not found in data' % (val, ))

  # find the end of the entries with value val
  finish = start + 1
  while finish < nOrig and tmp[finish][col] == val:
    finish += 1

  # how many entries have the target value?
  nWithVal = finish - start

  # how many don't?
  nOthers = len(tmp) - nWithVal

  currFrac = float(nWithVal) / nOrig
  if currFrac < frac:
    #
    # We're going to keep most of (all) the points with the target value,
    #  We need to figure out how many of the other points we'll
    #  toss out
    #
    nTgtFinal = nWithVal
    nFinal = int(round(nWithVal / frac))
    nOthersFinal = nFinal - nTgtFinal

    #
    # We may need to reduce the number of targets to keep
    #  because it may make it impossible to hit exactly the
    #  fraction we're trying for.  Take care of that now
    #
    while float(nTgtFinal) / nFinal > frac:
      nTgtFinal -= 1
      nFinal -= 1

  else:
    #
    # There are too many points with the target value,
    #  we'll keep most of (all) the other points and toss a random
    #  selection of the target value points
    #
    nOthersFinal = nOthers
    nFinal = int(round(nOthers / (1 - frac)))
    nTgtFinal = nFinal - nOthersFinal

    #
    # We may need to reduce the number of others to keep
    #  because it may make it impossible to hit exactly the
    #  fraction we're trying for.  Take care of that now
    #
    while float(nTgtFinal) / nFinal < frac:
      nOthersFinal -= 1
      nFinal -= 1

  # positions (in the sorted list) of the points without the target value
  others = list(range(start)) + list(range(finish, nOrig))
  othersTake = permutation(nOthers)
  others = [others[x] for x in othersTake[:nOthersFinal]]

  # positions (in the sorted list) of the points with the target value
  targets = list(range(start, finish))
  targetsTake = permutation(nWithVal)
  targets = [targets[x] for x in targetsTake[:nTgtFinal]]

  # these are all the indices we'll be keeping; a set keeps the membership
  # tests below cheap
  indicesToKeep = set(targets + others)

  res = []
  rej = []
  # now pull the points, but in random order
  if not indicesOnly:
    for i in permutation(nOrig):
      if i in indicesToKeep:
        res.append(tmp[i])
      else:
        rej.append(tmp[i])
  else:
    for i in permutation(nOrig):
      # map the sorted position back to an index into inData
      if not indicesToUse:
        idx = sortOrder[i]
      else:
        idx = indicesToUse[sortOrder[i]]
      if i in indicesToKeep:
        res.append(idx)
      else:
        rej.append(idx)
  return res, rej


def CountResults(inData, col=-1, bounds=None):
  """ counts how often each result occurs in a set of data points

    **Arguments**

      - inData: a sequence of data points

      - col: (optional) the column holding the value to be counted

      - bounds: (optional) if provided (and not empty), each value is first
        replaced by the index of the first bound it is smaller than, or by
        the number of bounds if it is not smaller than any of them

    **Returns**

      a dictionary mapping each value (or bin index) to its number of
      occurrences

  """
  counts = {}
  for point in inData:
    if bounds:
      activity = point[col]
      # index of the first bound above the activity, len(bounds) if none
      key = next((idx for idx, bound in enumerate(bounds) if activity < bound), len(bounds))
    else:
      key = point[col]
    if key in counts:
      counts[key] += 1
    else:
      counts[key] = 1
  return counts


def RandomizeActivities(dataSet, shuffle=0, runDetails=None):
  """ randomizes the activity values of a dataset

    **Arguments**

      - dataSet: a _ML.Data.MLQuantDataSet_, the activities here will be randomized

      - shuffle: an optional toggle. If this is set, the activity values
        will be shuffled (so the number in each class remains constant).
        Otherwise every point gets a freshly drawn random activity.

      - runDetails: an optional CompositeRun object; its _shuffled_ or
        _randomized_ attribute is set to 1 to record what was done

    **Note**

      - the data set is modified in place: the last element of every point
        is replaced

  """
  nPts = dataSet.GetNPts()
  if shuffle:
    if runDetails:
      runDetails.shuffled = 1
    acts = dataSet.GetResults()[:]
    # Fisher-Yates shuffle driven by random.random().  This is what
    # random.shuffle(acts, random=random.random) did; the `random` argument
    # was removed in Python 3.11.  The plain random.shuffle(acts) produces a
    # different sequence for the same seed, which the shuffle tests in
    # UnitTestScreenComposite depend on, so the loop is written out here.
    for i in reversed(range(1, len(acts))):
      j = int(random.random() * (i + 1))
      acts[i], acts[j] = acts[j], acts[i]
  else:
    if runDetails:
      runDetails.randomized = 1
    nPossible = dataSet.GetNPossibleVals()[-1]
    # one random activity per point
    # NOTE(review): randint() includes its upper limit.  If nPossible is the
    # number of possible values this can produce one value too many --
    # confirm against GetNPossibleVals().
    acts = [random.randint(0, nPossible) for _ in range(nPts)]
  for i in range(nPts):
    tmp = dataSet[i]
    tmp[-1] = acts[i]
    dataSet[i] = tmp


# ------------------------------------
#
#  doctest boilerplate
#
def _runDoctests(verbose=None):  # pragma: nocover
  """ runs the doctests of this module and exits the interpreter

    The number of failed tests is used as the exit status.

  """
  import doctest
  import sys
  results = doctest.testmod(optionflags=doctest.ELLIPSIS, verbose=verbose)
  # testmod() returns (failed, attempted)
  sys.exit(results[0])


# Run the doctests (and exit) when the module is executed as a script.
if __name__ == '__main__':  # pragma: nocover
  _runDoctests()
